import pandas as pd
import logging
import numpy as np
from copy import copy


class EssayPart(object):
    """
    Base class covering methods shared between different GnGExperiments.

    The class defines no ``__init__``. All instance methods operate on
    ``self.data``, which has to be set by the subclass/caller and is used
    here as a pandas DataFrame.
    """

    @staticmethod
    def lambda_most_common(x):
        """
        Return the most frequent non-NaN value of the Series ``x``.

        Returns NaN when ``x`` is empty or holds only NaN values
        (``value_counts`` drops NaN, so its result is empty in both cases).
        """
        counts = x.value_counts()
        return np.nan if counts.empty else counts.index[0]

    @staticmethod
    def filterdic(dic, filterValue):
        """Return the list of keys of ``dic`` whose value equals ``filterValue``."""
        return [key for key, value in dic.items() if value == filterValue]

    # Default column -> aggregation mapping used by resample().
    # Entries look like {'Area': 'mean', 'TrialCnt': EssayPart.lambda_most_common}.
    # NOTE: this dict lives on the class; createDiffColumn() adds entries to
    # it, so those entries are visible to every instance sharing this dict.
    resample_how = {
    }

    # Column dtypes that resample() re-applies after aggregation.
    dtypes = {
        'Group': 'category',
        'Session ID': 'category',
        'MouseID': 'category',
        'TrialCnt': 'int'}

    # Placeholder meta information.
    meta = {
        'MouseID': '????',
        'Session': '????',
        'Phase': '????',
        'Group': '????'}

    def toDateTime(self):
        """Convert the index of ``self.data`` to datetime; log on failure."""
        try:
            self.data.index = pd.to_datetime(self.data.index.values)
        except Exception as e:
            logging.error('Setting index to datetime failed! %s', e)

    def toDateDelta(self):
        """Make the index of ``self.data`` relative to its first entry; log on failure."""
        try:
            self.data.index = self.data.index - self.data.index[0]
        except Exception as e:
            logging.error('Setting index to timedelta failed! %s', e)

    def setdtype(self, dtypes=None):
        """
        Cast columns of ``self.data`` to the given dtypes.

        :param dtypes: mapping column name -> dtype. When None, MouseID,
                       Session, Group and PeriodType are cast to 'category'.

        Columns missing from the data are skipped with a warning. Any other
        failure is logged and stops the conversion of the remaining columns.
        """
        if dtypes is None:
            # Built per call so that no mutable default is shared between calls.
            dtypes = {'MouseID': 'category', 'Session': 'category',
                      'Group': 'category', 'PeriodType': 'category'}
        try:
            for k, v in dtypes.items():
                try:
                    self.data[k] = self.data[k].astype(v)
                except KeyError:
                    logging.warning('Cannot find column %s! Skipping!', k)

        except Exception as e:
            logging.error('Setting dtypes failed! %s', e)

    def resample(self, rule, fill_method='pad', how=None):
        """
        Return a resampled shallow copy of this object (``self`` is untouched).

        :param rule: resampling rule handed to ``DataFrame.resample``.
        :param fill_method: how to fill empty bins after aggregation:
                            'pad'/'ffill' (forward), 'backfill'/'bfill'
                            (backward) or None (leave NaN). Unknown values
                            fall back to 'pad' with a warning.
        :param how: mapping column -> aggregation function. Defaults to
                    ``self.resample_how``. Columns missing from the data are
                    skipped with a warning.
        :return: the new object, or None when resampling failed (the error
                 is logged).
        """
        try:
            new = copy(self)
            new.data = self.data.copy()
            new.toDateTime()
            if how is None:
                how = self.resample_how

            # Keep only aggregations for columns that actually exist.
            howDic = {}
            for column, func in how.items():
                if column not in new.data.columns:
                    logging.warning('Column %s not present in data. Will skip it!', column)
                else:
                    howDic[column] = func

            resampled = new.data.resample(rule).agg(howDic)
            if fill_method is not None:
                fillers = {'pad': 'ffill', 'ffill': 'ffill',
                           'backfill': 'bfill', 'bfill': 'bfill'}
                filler = fillers.get(fill_method)
                if filler is None:
                    logging.warning('Unknown fill_method %r! Falling back to "pad".', fill_method)
                    filler = 'ffill'
                # ffill()/bfill() replace the deprecated fillna(method=...).
                resampled = getattr(resampled, filler)()
            new.data = resampled

            new.toDateDelta()
            new.setdtype(self.dtypes)
            return new
        except Exception as e:
            logging.error('Resampling failed! %s', e)

    def roundIndex(self, precision=-8):
        """
        Round the (timedelta) index of ``self.data`` in place.

        The index is first made relative via toDateDelta(), cast to int64 and
        rounded with ``np.round(..., precision)``; a negative precision rounds
        to the left of the decimal point. Failures are logged.
        """
        try:
            self.toDateDelta()
            self.data.set_index(pd.to_timedelta(np.round(self.data.index.astype(np.int64), precision)), inplace=True)
        except Exception as e:
            logging.error('Rounding Time index failed! %s', e)

    def createDiffColumn(self, column, periods=None):
        """
        For a specified column add three additional columns that represent the same data
        but marking its onset and offset.

        New columns names will have the suffix _Diff, _OnSet, _OffSet

        PS: this assumes the column to contain blocks of 0's and 1's

        :param column: name of the binary column in ``self.data``.
        :param periods: optional 2-D array; column 3 is used as row positions
                        whose value in ``column`` is inspected, and positions
                        with a truthy value are added as onsets
                        (presumably the first row of each period - TODO confirm).

        Non-binary columns are skipped with a warning. The new columns are
        also registered in ``resample_how`` with 'max' aggregation.
        """
        # Replace NaN with 0, and test if data in column is binary!
        data = self.data[column].replace(np.nan, 0).tolist()
        if set(np.unique(data)).issubset({0.0, 1.0}):
            # Leading 0 so that a block starting on the first row yields an onset.
            dData = np.diff([0] + data)
            onSet = np.where(dData == 1)[0]
            if periods is not None:
                firstPeriodValue = self.data.iloc[periods[:, 3].tolist(), self.data.columns.get_loc(column)]
                onSet = np.append(onSet, periods[np.where(firstPeriodValue), 3]).astype(int)
            # The -1 step sits on the first 0 after a block; the offset is the row before it.
            offSet = np.where(dData == -1)[0] - 1

            insertAt = self.data.columns.tolist().index(column) + 1
            # Each insert goes directly behind `column`, so the final order is
            # column, _OffSet, _OnSet, _Diff.
            for suffix in ['_Diff', '_OnSet', '_OffSet']:
                self.data.insert(insertAt, column + suffix, 0)

            # onSet/offSet are row positions, so they must be written with
            # iloc; label based .loc only worked for a default integer index.
            diffCol = self.data.columns.get_loc(column + '_Diff')
            onCol = self.data.columns.get_loc(column + '_OnSet')
            offCol = self.data.columns.get_loc(column + '_OffSet')
            self.data.iloc[onSet, diffCol] = 1
            self.data.iloc[offSet, diffCol] = -1
            self.data.iloc[onSet, onCol] = 1
            self.data.iloc[offSet, offCol] = 1

            # Add the columns to the resample_how dictionary
            for cn in ['_Diff', '_OnSet', '_OffSet']:
                self.resample_how[column + cn] = 'max'
        else:
            logging.warning('createDiffColumn: Column "%s" is not binary and will be skipped!', column)

    @staticmethod
    def IT(o):
        """
        Return ``o`` if it is an iterable container, otherwise ``[o]``.

        Strings and bytes are treated as single values and get wrapped, so
        that callers never iterate over their characters.
        """
        if isinstance(o, (str, bytes)):
            return [o]
        if hasattr(o, '__iter__'):
            return o
        return [o]

    @staticmethod
    def is_number(o):
        """Return True when ``float(o)`` succeeds, False otherwise."""
        try:
            float(o)
            return True
        except (TypeError, ValueError):
            # TypeError covers objects such as None that float() rejects.
            return False

    @staticmethod
    def filter_nan(o):
        """
        Return the elements of ``o`` as a list without NaN values.

        Number-like elements (see is_number) are converted to float;
        all other elements are kept unchanged.
        """
        filtered = []
        for t in o:
            if EssayPart.is_number(t):
                t = float(t)
                if not np.isnan(t):
                    filtered.append(t)
            else:
                filtered.append(t)
        return filtered

    def autoDetectResampleMethods(self):
        """
        Generate a dictionary that specifies an aggregation function for each column based on its values.
        Specifying mean for continuous values, max for binary numbers, and most_common for discrete.

        Columns without any non-NaN value get most_common, which yields NaN.

        :return: dict column -> aggregation, or None when detection failed
                 (the error is logged).
        """
        try:
            how = {}
            for cn in self.data.columns:
                unique = self.filter_nan(self.data[cn].unique())
                if not unique:
                    # Nothing to inspect; most_common is safe for any dtype.
                    cfunc = EssayPart.lambda_most_common
                elif len(unique) <= 2:
                    if self.is_number(unique[0]):
                        cfunc = 'max'
                    else:
                        cfunc = EssayPart.lambda_most_common
                elif len(unique) < 10:
                    cfunc = EssayPart.lambda_most_common
                else:
                    cfunc = 'mean'

                how[cn] = cfunc
            return how
        except Exception as e:
            logging.error('Auto detecting Resample method failed! %s', e)

    @staticmethod
    def assignTimeIndex(indexTarget, indexAssign):
        """
        Take two index lists, and find and assign for each element in indexAssign
        its closest value in indexTarget.

        Both inputs are walked front to back with a cursor that never moves
        backwards, so both are expected to be sorted in ascending order.
        On a tie the later target value wins.

        :param indexTarget: non-empty, indexable sequence of candidate values.
        :param indexAssign: iterable of values to map onto indexTarget.
        :return: list with one element of indexTarget per element of indexAssign.
        :raises IndexError: when indexTarget is empty.
        """
        N = len(indexTarget)
        if N == 0:
            raise IndexError('indexTarget must not be empty')

        assignment = []
        pos = 0
        for idx in indexAssign:
            # Advance as long as the next target is at least as close, so that
            # several targets can be skipped for a single assigned value.
            while (pos + 1) < N and abs(idx - indexTarget[pos + 1]) <= abs(idx - indexTarget[pos]):
                pos = pos + 1
            assignment.append(indexTarget[pos])

        return assignment
